{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TensorBoard 可视化\n",
    "[github](https://github.com/lanpa/tensorboard-pytorch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-24T09:39:39.910789Z",
     "start_time": "2017-12-24T09:39:39.398570Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "from torchvision.datasets import CIFAR10\n",
    "from utils import resnet\n",
    "from torchvision import transforms as tfs\n",
    "from datetime import datetime\n",
    "from tensorboardX import SummaryWriter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-24T09:39:41.981293Z",
     "start_time": "2017-12-24T09:39:40.621895Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 使用数据增强\n",
    "def train_tf(x):\n",
    "    im_aug = tfs.Compose([\n",
    "        tfs.Resize(120),\n",
    "        tfs.RandomHorizontalFlip(),\n",
    "        tfs.RandomCrop(96),\n",
    "        tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),\n",
    "        tfs.ToTensor(),\n",
    "        tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n",
    "    ])\n",
    "    x = im_aug(x)\n",
    "    return x\n",
    "\n",
    "def test_tf(x):\n",
    "    im_aug = tfs.Compose([\n",
    "        tfs.Resize(96),\n",
    "        tfs.ToTensor(),\n",
    "        tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n",
    "    ])\n",
    "    x = im_aug(x)\n",
    "    return x\n",
    "\n",
    "train_set = CIFAR10('./data', train=True, transform=train_tf)\n",
    "train_data = torch.utils.data.DataLoader(train_set, batch_size=256, shuffle=True, num_workers=4)\n",
    "valid_set = CIFAR10('./data', train=False, transform=test_tf)\n",
    "valid_data = torch.utils.data.DataLoader(valid_set, batch_size=256, shuffle=False, num_workers=4)\n",
    "\n",
    "net = resnet(3, 10)\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr=0.1, weight_decay=1e-4)\n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-24T09:53:40.434024Z",
     "start_time": "2017-12-24T09:39:41.984480Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0. Train Loss: 1.877906, Train Acc: 0.315410, Valid Loss: 2.198587, Valid Acc: 0.293164, Time 00:00:26\n",
      "Epoch 1. Train Loss: 1.398501, Train Acc: 0.498657, Valid Loss: 1.877540, Valid Acc: 0.400098, Time 00:00:27\n",
      "Epoch 2. Train Loss: 1.141419, Train Acc: 0.597628, Valid Loss: 1.872355, Valid Acc: 0.446777, Time 00:00:27\n",
      "Epoch 3. Train Loss: 0.980048, Train Acc: 0.658367, Valid Loss: 1.672951, Valid Acc: 0.475391, Time 00:00:27\n",
      "Epoch 4. Train Loss: 0.871448, Train Acc: 0.695073, Valid Loss: 1.263234, Valid Acc: 0.578613, Time 00:00:28\n",
      "Epoch 5. Train Loss: 0.794649, Train Acc: 0.723992, Valid Loss: 2.142715, Valid Acc: 0.466699, Time 00:00:27\n",
      "Epoch 6. Train Loss: 0.736611, Train Acc: 0.741554, Valid Loss: 1.701331, Valid Acc: 0.500391, Time 00:00:27\n",
      "Epoch 7. Train Loss: 0.695095, Train Acc: 0.756816, Valid Loss: 1.385478, Valid Acc: 0.597656, Time 00:00:28\n",
      "Epoch 8. Train Loss: 0.652659, Train Acc: 0.773796, Valid Loss: 1.029726, Valid Acc: 0.676465, Time 00:00:27\n",
      "Epoch 9. Train Loss: 0.623829, Train Acc: 0.784144, Valid Loss: 0.933388, Valid Acc: 0.682520, Time 00:00:27\n",
      "Epoch 10. Train Loss: 0.581615, Train Acc: 0.798792, Valid Loss: 1.291557, Valid Acc: 0.635938, Time 00:00:27\n",
      "Epoch 11. Train Loss: 0.559358, Train Acc: 0.805708, Valid Loss: 1.430408, Valid Acc: 0.586426, Time 00:00:28\n",
      "Epoch 12. Train Loss: 0.534197, Train Acc: 0.816853, Valid Loss: 0.960802, Valid Acc: 0.704785, Time 00:00:27\n",
      "Epoch 13. Train Loss: 0.512111, Train Acc: 0.822389, Valid Loss: 0.923353, Valid Acc: 0.716602, Time 00:00:27\n",
      "Epoch 14. Train Loss: 0.494577, Train Acc: 0.828225, Valid Loss: 1.023517, Valid Acc: 0.687207, Time 00:00:27\n",
      "Epoch 15. Train Loss: 0.473396, Train Acc: 0.835212, Valid Loss: 0.842679, Valid Acc: 0.727930, Time 00:00:27\n",
      "Epoch 16. Train Loss: 0.459708, Train Acc: 0.840290, Valid Loss: 0.826854, Valid Acc: 0.726953, Time 00:00:28\n",
      "Epoch 17. Train Loss: 0.433836, Train Acc: 0.847931, Valid Loss: 0.730658, Valid Acc: 0.764258, Time 00:00:27\n",
      "Epoch 18. Train Loss: 0.422375, Train Acc: 0.854401, Valid Loss: 0.677953, Valid Acc: 0.778125, Time 00:00:27\n",
      "Epoch 19. Train Loss: 0.410208, Train Acc: 0.857370, Valid Loss: 0.787286, Valid Acc: 0.754102, Time 00:00:27\n",
      "Epoch 20. Train Loss: 0.395556, Train Acc: 0.862923, Valid Loss: 0.859754, Valid Acc: 0.738965, Time 00:00:27\n",
      "Epoch 21. Train Loss: 0.382050, Train Acc: 0.866554, Valid Loss: 1.266704, Valid Acc: 0.651660, Time 00:00:27\n",
      "Epoch 22. Train Loss: 0.368614, Train Acc: 0.871213, Valid Loss: 0.912465, Valid Acc: 0.738672, Time 00:00:27\n",
      "Epoch 23. Train Loss: 0.358302, Train Acc: 0.873964, Valid Loss: 0.963238, Valid Acc: 0.706055, Time 00:00:27\n",
      "Epoch 24. Train Loss: 0.347568, Train Acc: 0.879620, Valid Loss: 0.777171, Valid Acc: 0.751855, Time 00:00:27\n",
      "Epoch 25. Train Loss: 0.339247, Train Acc: 0.882215, Valid Loss: 0.707863, Valid Acc: 0.777734, Time 00:00:27\n",
      "Epoch 26. Train Loss: 0.329292, Train Acc: 0.885830, Valid Loss: 0.682976, Valid Acc: 0.790527, Time 00:00:27\n",
      "Epoch 27. Train Loss: 0.313049, Train Acc: 0.890761, Valid Loss: 0.665912, Valid Acc: 0.795410, Time 00:00:27\n",
      "Epoch 28. Train Loss: 0.305482, Train Acc: 0.891944, Valid Loss: 0.880263, Valid Acc: 0.743848, Time 00:00:27\n",
      "Epoch 29. Train Loss: 0.301507, Train Acc: 0.895289, Valid Loss: 1.062325, Valid Acc: 0.708398, Time 00:00:27\n"
     ]
    }
   ],
   "source": [
    "writer = SummaryWriter()\n",
    "\n",
    "def get_acc(output, label):\n",
    "    total = output.shape[0]\n",
    "    _, pred_label = output.max(1)\n",
    "    num_correct = (pred_label == label).sum().data[0]\n",
    "    return num_correct / total\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    net = net.cuda()\n",
    "prev_time = datetime.now()\n",
    "for epoch in range(30):\n",
    "    train_loss = 0\n",
    "    train_acc = 0\n",
    "    net = net.train()\n",
    "    for im, label in train_data:\n",
    "        if torch.cuda.is_available():\n",
    "            im = Variable(im.cuda())  # (bs, 3, h, w)\n",
    "            label = Variable(label.cuda())  # (bs, h, w)\n",
    "        else:\n",
    "            im = Variable(im)\n",
    "            label = Variable(label)\n",
    "        # forward\n",
    "        output = net(im)\n",
    "        loss = criterion(output, label)\n",
    "        # backward\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        train_loss += loss.data[0]\n",
    "        train_acc += get_acc(output, label)\n",
    "    cur_time = datetime.now()\n",
    "    h, remainder = divmod((cur_time - prev_time).seconds, 3600)\n",
    "    m, s = divmod(remainder, 60)\n",
    "    time_str = \"Time %02d:%02d:%02d\" % (h, m, s)\n",
    "    valid_loss = 0\n",
    "    valid_acc = 0\n",
    "    net = net.eval()\n",
    "    for im, label in valid_data:\n",
    "        if torch.cuda.is_available():\n",
    "            im = Variable(im.cuda(), volatile=True)\n",
    "            label = Variable(label.cuda(), volatile=True)\n",
    "        else:\n",
    "            im = Variable(im, volatile=True)\n",
    "            label = Variable(label, volatile=True)\n",
    "        output = net(im)\n",
    "        loss = criterion(output, label)\n",
    "        valid_loss += loss.data[0]\n",
    "        valid_acc += get_acc(output, label)\n",
    "    epoch_str = (\n",
    "                \"Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, \"\n",
    "                % (epoch, train_loss / len(train_data),\n",
    "                   train_acc / len(train_data), valid_loss / len(valid_data),\n",
    "                   valid_acc / len(valid_data)))\n",
    "    prev_time = cur_time\n",
    "    # ====================== 使用 tensorboard ==================\n",
    "    writer.add_scalars('Loss', {'train': train_loss / len(train_data),\n",
    "                                'valid': valid_loss / len(valid_data)}, epoch)\n",
    "    writer.add_scalars('Acc', {'train': train_acc / len(train_data),\n",
    "                               'valid': valid_acc / len(valid_data)}, epoch)\n",
    "    # =========================================================\n",
    "    print(epoch_str + time_str)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](https://ws1.sinaimg.cn/large/006tNc79ly1fms31s3i4yj31gc0qimy6.jpg)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mx",
   "language": "python",
   "name": "mx"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
